{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5d60cbb9-2a6a-43ea-a9e9-f67b16ddd2b2",
   "metadata": {},
   "source": [
    "# Tool error handling\n",
    "\n",
    "Using a model to invoke a tool has some obvious potential failure modes. Firstly, the model needs to return a output that can be parsed at all. Secondly, the model needs to return tool arguments that are valid.\n",
    "\n",
    "We can build error handling into our chains to mitigate these failure modes."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "712c774f-27c7-4351-a196-39900ca155f5",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "We'll need to install the following packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63056c24-9834-4e3d-8bc5-54b1e6c5df86",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet langchain langchain-openai"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68107597-0c8c-4bb5-8c12-9992fabdf71a",
   "metadata": {},
   "source": [
    "And set these environment variables:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08785b6d-722d-4620-b6ec-36deb3842c69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
    "\n",
    "# If you'd like to use LangSmith, uncomment the below:\n",
    "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
    "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a50f93a-5d6f-4691-8f98-27239a1c2f95",
   "metadata": {},
   "source": [
    "## Chain\n",
    "\n",
    "Suppose we have the following (dummy) tool and tool-calling chain. We'll make our tool intentionally convoluted to try and trip up the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d20604e-c4d1-4d21-841b-23e4f61aec36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define tool\n",
    "from langchain_core.tools import tool\n",
    "\n",
    "\n",
    "@tool\n",
    "def complex_tool(int_arg: int, float_arg: float, dict_arg: dict) -> int:\n",
    "    \"\"\"Do something complex with a complex tool.\"\"\"\n",
    "    return int_arg * float_arg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "553c2c13-28c8-4451-8a3a-6c31d52dc31d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model and bind tool\n",
    "from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "model = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "model_with_tools = model.bind(\n",
    "    tools=[format_tool_to_openai_tool(complex_tool)],\n",
    "    tool_choice={\"type\": \"function\", \"function\": {\"name\": \"complex_tool\"}},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "802b2eca-9f79-4d6c-8257-85139ca5c752",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define chain\n",
    "from operator import itemgetter\n",
    "\n",
    "from langchain.output_parsers import JsonOutputKeyToolsParser\n",
    "from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough\n",
    "\n",
    "chain = (\n",
    "    model_with_tools\n",
    "    | JsonOutputKeyToolsParser(key_name=\"complex_tool\", return_single=True)\n",
    "    | complex_tool\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c34f005e-63f0-4841-9461-ca36c36607fc",
   "metadata": {},
   "source": [
    "We can see that when we try to invoke this chain with even a fairly explicit input, the model fails to correctly call the tool (it forgets the `dict_arg` argument)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d354664c-ac44-4967-a35f-8912b3ad9477",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValidationError",
     "evalue": "1 validation error for complex_toolSchemaSchema\ndict_arg\n  field required (type=value_error.missing)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValidationError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mchain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m      2\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43muse complex tool. the args are 5, 2.1, empty dictionary. don\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mt forget dict_arg\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m      3\u001b[0m \u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/langchain/libs/core/langchain_core/runnables/base.py:1774\u001b[0m, in \u001b[0;36mRunnableSequence.invoke\u001b[0;34m(self, input, config)\u001b[0m\n\u001b[1;32m   1772\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1773\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m i, step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps):\n\u001b[0;32m-> 1774\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mstep\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1775\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1776\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;66;43;03m# mark each step as a child run\u001b[39;49;00m\n\u001b[1;32m   1777\u001b[0m \u001b[43m            \u001b[49m\u001b[43mpatch_config\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1778\u001b[0m \u001b[43m                \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_child\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mseq:step:\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mi\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1779\u001b[0m \u001b[43m            \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1780\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1781\u001b[0m \u001b[38;5;66;03m# finish the root run\u001b[39;00m\n\u001b[1;32m   1782\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "File \u001b[0;32m~/langchain/libs/core/langchain_core/tools.py:210\u001b[0m, in \u001b[0;36mBaseTool.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minvoke\u001b[39m(\n\u001b[1;32m    204\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    205\u001b[0m     \u001b[38;5;28minput\u001b[39m: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[1;32m    206\u001b[0m     config: Optional[RunnableConfig] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    207\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m    208\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m    209\u001b[0m     config \u001b[38;5;241m=\u001b[39m ensure_config(config)\n\u001b[0;32m--> 210\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    211\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    212\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcallbacks\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    213\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtags\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtags\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    214\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmetadata\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    215\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrun_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrun_name\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    216\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    217\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/langchain/libs/core/langchain_core/tools.py:315\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, callbacks, tags, metadata, run_name, **kwargs)\u001b[0m\n\u001b[1;32m    301\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\n\u001b[1;32m    302\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    303\u001b[0m     tool_input: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    312\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m    313\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m    314\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 315\u001b[0m     parsed_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    316\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;129;01mand\u001b[39;00m verbose \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    317\u001b[0m         verbose_ \u001b[38;5;241m=\u001b[39m verbose\n",
      "File \u001b[0;32m~/langchain/libs/core/langchain_core/tools.py:250\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m    248\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    249\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m input_args \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 250\u001b[0m         result \u001b[38;5;241m=\u001b[39m \u001b[43minput_args\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    251\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m    252\u001b[0m             k: \u001b[38;5;28mgetattr\u001b[39m(result, k)\n\u001b[1;32m    253\u001b[0m             \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m result\u001b[38;5;241m.\u001b[39mdict()\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m    254\u001b[0m             \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m tool_input\n\u001b[1;32m    255\u001b[0m         }\n\u001b[1;32m    256\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tool_input\n",
      "File \u001b[0;32m~/langchain/.venv/lib/python3.9/site-packages/pydantic/v1/main.py:526\u001b[0m, in \u001b[0;36mBaseModel.parse_obj\u001b[0;34m(cls, obj)\u001b[0m\n\u001b[1;32m    524\u001b[0m         exc \u001b[38;5;241m=\u001b[39m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m expected dict not \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mobj\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    525\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m ValidationError([ErrorWrapper(exc, loc\u001b[38;5;241m=\u001b[39mROOT_KEY)], \u001b[38;5;28mcls\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[0;32m--> 526\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/langchain/.venv/lib/python3.9/site-packages/pydantic/v1/main.py:341\u001b[0m, in \u001b[0;36mBaseModel.__init__\u001b[0;34m(__pydantic_self__, **data)\u001b[0m\n\u001b[1;32m    339\u001b[0m values, fields_set, validation_error \u001b[38;5;241m=\u001b[39m validate_model(__pydantic_self__\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m, data)\n\u001b[1;32m    340\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m validation_error:\n\u001b[0;32m--> 341\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m validation_error\n\u001b[1;32m    342\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    343\u001b[0m     object_setattr(__pydantic_self__, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__dict__\u001b[39m\u001b[38;5;124m'\u001b[39m, values)\n",
      "\u001b[0;31mValidationError\u001b[0m: 1 validation error for complex_toolSchemaSchema\ndict_arg\n  field required (type=value_error.missing)"
     ]
    }
   ],
   "source": [
    "chain.invoke(\n",
    "    \"use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "890d989d-2d39-4571-9a55-d3496b9b5d27",
   "metadata": {},
   "source": [
    "## Try/except tool call\n",
    "\n",
    "The simplest way to more gracefully handle errors is to try/except the tool-calling step and return a helpful message on errors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8fedb550-683d-45ae-8876-ae7acb332019",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "from langchain_core.runnables import RunnableConfig\n",
    "\n",
    "\n",
    "def try_except_tool(tool_args: dict, config: RunnableConfig) -> Runnable:\n",
    "    try:\n",
    "        complex_tool.invoke(tool_args, config=config)\n",
    "    except Exception as e:\n",
    "        return f\"Calling tool with arguments:\\n\\n{tool_args}\\n\\nraised the following error:\\n\\n{type(e)}: {e}\"\n",
    "\n",
    "\n",
    "chain = (\n",
    "    model_with_tools\n",
    "    | JsonOutputKeyToolsParser(key_name=\"complex_tool\", return_single=True)\n",
    "    | try_except_tool\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "71a2c98d-c0be-4c0a-bb3d-41ad4596526c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calling tool with arguments:\n",
      "\n",
      "{'int_arg': 5, 'float_arg': 2.1}\n",
      "\n",
      "raised the following error:\n",
      "\n",
      "<class 'pydantic.v1.error_wrappers.ValidationError'>: 1 validation error for complex_toolSchemaSchema\n",
      "dict_arg\n",
      "  field required (type=value_error.missing)\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    chain.invoke(\n",
    "        \"use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b2f6393-cb47-49d0-921c-09550a049fe4",
   "metadata": {},
   "source": [
    "## Fallbacks\n",
    "\n",
    "We can also try to fallback to a better model in the event of a tool invocation error. In this case we'll fall back to an identical chain that uses `gpt-4-1106-preview` instead of `gpt-3.5-turbo`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "02cc4223-35fa-4240-976a-012299ca703c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10.5"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain = (\n",
    "    model_with_tools\n",
    "    | JsonOutputKeyToolsParser(key_name=\"complex_tool\", return_single=True)\n",
    "    | complex_tool\n",
    ")\n",
    "better_model = ChatOpenAI(model=\"gpt-4-1106-preview\", temperature=0).bind(\n",
    "    tools=[format_tool_to_openai_tool(complex_tool)],\n",
    "    tool_choice={\"type\": \"function\", \"function\": {\"name\": \"complex_tool\"}},\n",
    ")\n",
    "better_chain = (\n",
    "    better_model\n",
    "    | JsonOutputKeyToolsParser(key_name=\"complex_tool\", return_single=True)\n",
    "    | complex_tool\n",
    ")\n",
    "\n",
    "chain_with_fallback = chain.with_fallbacks([better_chain])\n",
    "chain_with_fallback.invoke(\n",
    "    \"use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "412f8c4e-cc83-4d87-84a1-5ba2f8edb1e9",
   "metadata": {},
   "source": [
    "Looking at the [Langsmith trace](https://smith.langchain.com/public/241e1266-8555-4d49-99dc-b8df46109c39/r) for this chain run, we can see that the first chain call fails as expected and it's the fallback that succeeds."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "304b59cd-cd25-4205-9769-36595c8f3b59",
   "metadata": {},
   "source": [
    "## Retry with exception\n",
    "\n",
    "To take things one step further, we can try to automatically re-run the chain with the exception passed in, so that the model may be able to correct its behavior:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b5659956-9454-468a-9753-a3ff9052b8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Any\n",
    "\n",
    "from langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n",
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "\n",
    "\n",
    "class CustomToolException(Exception):\n",
    "    \"\"\"Custom LangChain tool exception.\"\"\"\n",
    "\n",
    "    def __init__(self, tool_call: dict, exception: Exception) -> None:\n",
    "        super().__init__()\n",
    "        self.tool_call = tool_call\n",
    "        self.exception = exception\n",
    "\n",
    "\n",
    "def tool_custom_exception(tool_call: dict, config: RunnableConfig) -> Runnable:\n",
    "    try:\n",
    "        return complex_tool.invoke(tool_call[\"args\"], config=config)\n",
    "    except Exception as e:\n",
    "        raise CustomToolException(tool_call, e)\n",
    "\n",
    "\n",
    "def exception_to_messages(inputs: dict) -> dict:\n",
    "    exception = inputs.pop(\"exception\")\n",
    "    tool_call = {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"complex_tool\",\n",
    "            \"arguments\": json.dumps(exception.tool_call[\"args\"]),\n",
    "        },\n",
    "        \"id\": exception.tool_call[\"id\"],\n",
    "    }\n",
    "\n",
    "    # Add historical messages to the original input, so the model knows that it made a mistake with the last tool call.\n",
    "    messages = [\n",
    "        AIMessage(content=\"\", additional_kwargs={\"tool_calls\": [tool_call]}),\n",
    "        ToolMessage(tool_call_id=tool_call[\"id\"], content=str(exception.exception)),\n",
    "        HumanMessage(\n",
    "            content=\"The last tool calls raised exceptions. Try calling the tools again with corrected arguments.\"\n",
    "        ),\n",
    "    ]\n",
    "    inputs[\"last_output\"] = messages\n",
    "    return inputs\n",
    "\n",
    "\n",
    "# We add a last_output MessagesPlaceholder to our prompt which if not passed in doesn't\n",
    "# affect the prompt at all, but gives us the option to insert an arbitrary list of Messages\n",
    "# into the prompt if needed. We'll use this on retries to insert the error message.\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"human\", \"{input}\"), MessagesPlaceholder(\"last_output\", optional=True)]\n",
    ")\n",
    "chain = (\n",
    "    prompt\n",
    "    | model_with_tools\n",
    "    | JsonOutputKeyToolsParser(\n",
    "        key_name=\"complex_tool\", return_id=True, return_single=True\n",
    "    )\n",
    "    | tool_custom_exception\n",
    ")\n",
    "\n",
    "# If the initial chain call fails, we rerun it withe the exception passed in as a message.\n",
    "self_correcting_chain = chain.with_fallbacks(\n",
    "    [exception_to_messages | chain], exception_key=\"exception\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4c45f5bd-cbb4-47d5-b4b6-aec50673c750",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10.5"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "self_correcting_chain.invoke(\n",
    "    {\n",
    "        \"input\": \"use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50d269a9-3cab-4a37-ba2f-805296453627",
   "metadata": {},
   "source": [
    "And our chain succeeds! Looking at the [LangSmith trace](https://smith.langchain.com/public/b780b740-daf5-43aa-a217-6d4600aba41b/r), we can see that indeed our initial chain still fails, and it's only on retrying that the chain succeeds."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "poetry-venv",
   "language": "python",
   "name": "poetry-venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
